# Load the final merged dataframe
# Read the merged benchmarking table; the first CSV column is used as row names.
result <- read.csv(
  "results/benchmarking/deseq/comp_adj/parallel_models/merged.csv",
  row.names = 1
)

# Eleven colours, one per entry of `models.legend` below
# (presumably paired with the legend breaks in order -- TODO confirm).
pal <- c(
  "#969696",
  "#fcc5c0", "#fa9fb5", "#f768a1", "#c51b8a",
  "#addd8e", "#78c679", "#31a354", "#006837",
  "#fd8d3c", "#e6550d"
)

# Legend text for each model design. The names are the values used as scale
# breaks (the plots below map `sublist_name` to colour/fill).
models.legend <- list(
  unadjusted = ~ "Group",
  raw_cm = ~ "Group + CM prop",
  raw_cm_fb = ~ "Group + CM prop + 1x prop of minor cells",
  raw_cm_fb_im = ~ "Group + CM prop + 2x prop of minor cells",
  raw_4 = ~ "Group + CM prop + 3x prop of minor cells",
  clr_cm = ~ "Group + CLR of CMs",
  clr_cm_fb = ~ "Group + CLR of CMs + 1x CLR of minor cells",
  clr_cm_fb_im = ~ "Group + CLR of CMs + 2x CLR of minor cells",
  clr_4 = ~ "Group + CLR of CMs + 3x CLR of minor cells",
  pc1 = ~ "Group + PC1",
  pc2 = ~ "Group + PC1 + PC2"
)

# False positives against the simulated CM proportion difference: one colour
# per model design, raw points plus a LOESS trend line per design.
# The +0.01 offset is presumably there so zero counts survive the (currently
# disabled) log-scaled y axis -- TODO confirm.
p.fp <- result |>
  ggplot(aes(x = sub_sublist_name, y = false_pos + 0.01, color = sublist_name)) +
  geom_point(alpha = 0.4) + # smaller and more transparent points
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(se = FALSE, method = "loess", size = 1.5, alpha = 0.2, span = 0.3) +
  #scale_y_continuous(trans = "log2") +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  theme(
    legend.position = "none", # legend is hidden in this panel
    legend.title = element_blank(),
    text = element_text(size = 12),
    panel.grid.major = element_blank(), # remove major gridlines
    panel.grid.minor = element_blank()  # remove minor gridlines
  ) +
  #geom_magnify(from = c(-0.07, 0.07,-3,30), to = c(-0.08,0.08, 380,600), 
  #                 shape = "rect", shadow = F)  +
  labs(y = "False Positives in 500 genes", x = "Simulated CM proportion difference") +
  ggtitle("False positive rate in simulated dataset\n(25mil counts, mean major prop is 0.5, 5 replicates)")

# True positives against the simulated CM proportion difference: one colour
# (and one explicit group) per model design, raw points plus a LOESS trend line.
p.tp <- result |>
  ggplot(aes(x = sub_sublist_name, y = true_pos + 0.01, color = sublist_name, group = sublist_name)) +
  geom_point(alpha = 0.2) + # smaller and more transparent points
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(se = FALSE, method = "loess", size = 1.5, alpha = 0.2, span = 0.3) +
  #scale_y_continuous(trans = "log10") +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  theme(
    legend.position = "none", # legend is hidden in this panel
    legend.title = element_blank(),
    text = element_text(size = 12),
    panel.grid.major = element_blank(), # remove major gridlines
    panel.grid.minor = element_blank()  # remove minor gridlines
  ) +
  #geom_magnify(from = c(-0.07, 0.07,-3,30), to = c(-0.08,0.08, 380,600), 
  #                 shape = "rect", shadow = F)  +
  labs(y = "True Positives out of 100 DE genes", x = "Simulated CM proportion difference") +
  ggtitle("True positive rate (2x fold change)")
# Want to take the ratio of the false positive rate over the true positive rate
# Bin the absolute proportion difference into 0.05-wide intervals, labelled by
# their lower edge. The small tolerance keeps exact multiples of 0.05 in their
# own bin: in floating point 0.15 / 0.05 evaluates to just under 3, so a bare
# floor() would assign 0.15 to the 0.10 bin.
result <- result %>%
  mutate(distance_bin = floor(abs(sub_sublist_name) / 0.05 + 1e-8) * 0.05)
 
# Derive the remaining confusion-matrix counts and the classification metrics
# for every row of `result`.
rates.df <- result %>%
  mutate(
    false_neg = total_pos - true_pos,
    # NOTE(review): assumes 500 genes in total per simulation -- confirm
    # against the simulation setup.
    true_neg = 500 - (true_pos + false_pos + false_neg),
    Precision = true_pos / (true_pos + false_pos),
    Recall = true_pos / (true_pos + false_neg),
    "F1 score" = 2 * (Precision * Recall) / (Precision + Recall),
    Specificity = true_neg / (true_neg + false_pos)
  )
# Reshape to long format (one row per metric per original row), attach the
# per-metric, per-model median and sort by it.
rates.df <- rates.df %>%
  pivot_longer(cols = c(Precision, Recall, "F1 score", Specificity),
               names_to = "Metric",
               values_to = "Value") %>%
  group_by(Metric, sublist_name) %>%
  mutate(median_value = median(Value, na.rm = TRUE)) %>%
  # Drop the grouping so later verbs on `rates.df` do not silently run per group.
  ungroup() %>%
  arrange(Metric, median_value)
# Box plots of every metric per model design, one facet per metric with free
# axes. The x axis order follows the first appearance of each design in
# `rates.df`.
p.stats <- ggplot(data = rates.df, mapping = aes(x = sublist_name, y = Value)) +
  geom_boxplot(mapping = aes(fill = sublist_name)) +
  facet_wrap(~ Metric, scales = "free") +
  scale_x_discrete(limits = unique(rates.df[["sublist_name"]])) +
  scale_fill_manual(
    name = "Model design",
    values = pal,
    breaks = names(models.legend),
    labels = models.legend
  ) +
  labs(
    title = "Performance Metrics by Model Type",
    x = "Model and Sublist",
    y = "Metric Value"
  ) +
  theme_minimal() +
  theme(
    legend.position = "right",
    strip.background = element_blank(),
    strip.text.x = element_text(size = 12),
    axis.text.x = element_text(size = 7, angle = 30)
  )

# Keep only the F1 rows of the long-format metrics table
result_f1 <- rates.df %>%
  filter(Metric == "F1 score")

# Scatterplot of F1 against the absolute proportion difference, with a LOESS
# trend line per model design.
# `Value` holds the per-row metric. The previous `y = mean_value` referred to a
# column that the pipeline above never creates (it only adds `median_value`).
p.f1 <- ggplot(result_f1, aes(x = abs(sub_sublist_name), y = Value)) +
  geom_point(aes(color = sublist_name), size = 4, alpha = 0.2) +
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(aes(color = sublist_name), se = FALSE, method = "loess", size = 1.5, alpha = 0.3, span = 0.3) +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  labs(
    title = "F1 Score by Absolute Percent Difference",
    x = "Absolute Percent Difference",
    y = "F1 Score"
  )


# Layout string for wrap_plots() (presumably patchwork): each text row is one
# row of the figure, so A sits above B on the left and C fills the right column.
design <- c("AC
             BC")
# Combine the false-positive, true-positive and metric panels into one figure.
wrap_plots(A = p.fp, B = p.tp, C = p.stats, design = design)
# The `##` lines below appear to be console messages captured when the figure
# was rendered.
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'

# Second palette: same length as before, with the middle four colours swapped
# for blues. Overwrites the `pal` defined earlier.
pal <- c(
  "#969696",
  "#fcc5c0", "#fa9fb5", "#f768a1", "#c51b8a",
  "#c6dbef", "#9ecae1", "#6baed6", "#3182bd",
  "#fd8d3c", "#e6550d"
)

# Shortened legend text for each model design; overwrites the earlier
# `models.legend` while keeping the same names.
models.legend <- list(
  unadjusted = ~ "Unadjusted",
  raw_cm = ~ "CM prop",
  raw_cm_fb = ~ "CM prop + 1x",
  raw_cm_fb_im = ~ "CM prop + 2x",
  raw_4 = ~ "CM prop + 3x",
  clr_cm = ~ "CLR of CMs",
  clr_cm_fb = ~ "CLR of CMs + 1x",
  clr_cm_fb_im = ~ "CLR of CMs + 2x",
  clr_4 = ~ "CLR of CMs + 3x",
  pc1 = ~ "PC1",
  pc2 = ~ "PC1 + PC2"
)

# False-positive panel, rebuilt with larger text and a smoother trend line,
# excluding rows whose ModelType is "Isometric log ratios".
p.fp <- result |>
  subset(ModelType != "Isometric log ratios") |>
  ggplot(aes(x = sub_sublist_name, y = false_pos + 0.01, color = sublist_name)) +
  geom_point(alpha = 0.1) + # smaller and more transparent points
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(se = FALSE, method = "loess", size = 1.5, alpha = 0.2, span = 0.8) +
  #scale_y_continuous(trans = "log2") +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  theme(
    legend.position = "none", # legend is hidden in this panel
    legend.title = element_blank(),
    text = element_text(size = 24),
    panel.grid.major = element_blank(), # remove major gridlines
    panel.grid.minor = element_blank()  # remove minor gridlines
  ) +
  #geom_magnify(from = c(-0.07, 0.07,-3,30), to = c(-0.08,0.08, 380,600), 
  #                 shape = "rect", shadow = F)  +
  labs(y = "False Positives in 500 genes", x = "Simulated CM proportion difference") +
  ggtitle("False Positive Rate")

# True-positive panel, rebuilt with larger text and a smoother trend line,
# excluding rows whose ModelType is "Isometric log ratios".
p.tp <- result |>
  subset(ModelType != "Isometric log ratios") |>
  ggplot(aes(x = sub_sublist_name, y = true_pos + 0.01, color = sublist_name, group = sublist_name)) +
  geom_point(alpha = 0.1) + # smaller and more transparent points
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(se = FALSE, method = "loess", size = 1.5, alpha = 0.2, span = 0.7) +
  #scale_y_continuous(trans = "log10") +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  # No trailing comma after the last argument: it would pass an empty argument
  # to theme().
  theme(
    legend.position = "none", # legend is hidden in this panel
    legend.title = element_blank(),
    text = element_text(size = 24),
    panel.grid.major = element_blank(), # remove major gridlines
    panel.grid.minor = element_blank(), # remove minor gridlines
    legend.key.width = unit(3, "cm"),
    legend.key.height = unit(1, "cm")
  ) +
  #geom_magnify(from = c(-0.07, 0.07,-3,30), to = c(-0.08,0.08, 380,600), 
  #                 shape = "rect", shadow = F)  +
  labs(y = "True Positives out of 100 DE genes", x = "Simulated CM proportion difference") +
  ggtitle("True Positive Rate")

# Variant of the false-positive panel whose points are drawn with
# position_dodge(0.2); intended to dodge by `distance_bin`.
# NOTE(review): `dodge` is not a documented geom_point() aesthetic, so ggplot2
# is expected to ignore it with a warning and dodge by the implicit group
# instead -- confirm the intended grouping before relying on this plot.
p.fp.bin <- result |>
  subset(ModelType != "Isometric log ratios") |>
  ggplot(aes(x = sub_sublist_name, y = false_pos + 0.01, color = sublist_name)) +
  geom_point(aes(dodge = distance_bin), position = position_dodge(0.2), alpha = 0.1) + # dodging
  # `FALSE` rather than `F`: `F` is an ordinary variable that can be reassigned.
  geom_smooth(se = FALSE, method = "loess", size = 1.5, alpha = 0.2, span = 0.8) +
  scale_color_manual(values = pal,
                     name = "Model design",
                     labels = models.legend,
                     breaks = names(models.legend)) +
  theme_minimal() +
  theme(
    legend.position = "none", # legend is hidden in this panel
    text = element_text(size = 24),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  ) +
  labs(y = "False Positives in 500 genes", x = "Simulated CM proportion difference") +
  ggtitle("False Positive Rate")


# F1-only box plot per model design (rows whose ModelType is
# "Isometric log ratios" are excluded), with the individual values jittered on
# top. The x axis is ordered by the names of `models.legend`.
p.stats <- rates.df |>
  subset(ModelType != "Isometric log ratios" &
           Metric == "F1 score") |>
  # `levels =` spelled out in full; `level =` only worked through partial
  # argument matching.
  ggplot(aes(x = factor(sublist_name, levels = names(models.legend)),
             y = Value)) +
  geom_boxplot(aes(fill = sublist_name)) +
  geom_jitter(width = 0.1, height = 0, alpha = 0.1) +
  scale_fill_manual(values = pal,
                    name = "Model design",
                    labels = models.legend,
                    breaks = names(models.legend)) +
  #facet_wrap(~ distance_bin) +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    axis.text.x = element_blank(), # designs are identified by the legend
    text = element_text(size = 24)
  ) +
  labs(
    title = "Model Accuracy",
    y = "F1 Score (higher is more accurate)",
    x = "Model Type"
  )

# Layout string for wrap_plots() (presumably patchwork): A and B share the top
# row and C spans the full width of the bottom row.
design <- c("AB
             CC")
# Combine the rebuilt false-positive, true-positive and F1 panels.
wrap_plots(A = p.fp, B = p.tp, C = p.stats, design = design)
# The `##` lines below appear to be console messages captured when the figure
# was rendered.
## `geom_smooth()` using formula = 'y ~ x'
## `geom_smooth()` using formula = 'y ~ x'

 # Box plots of all four metrics per model design, one facet per metric (rows
 # whose ModelType is "Isometric log ratios" are excluded). The result is not
 # assigned, so it is printed directly.
 rates.df |>
  subset(ModelType != "Isometric log ratios") |>
  # `levels =` spelled out in full; `level =` only worked through partial
  # argument matching.
  ggplot(aes(x = factor(sublist_name, levels = names(models.legend)),
             y = Value)) +
  geom_boxplot(aes(fill = sublist_name)) +
  geom_jitter(width = 0.1, height = 0, alpha = 0.1) +
  scale_fill_manual(values = pal,
                    name = "Model design",
                    labels = models.legend,
                    breaks = names(models.legend)) +
  facet_wrap(~ Metric, scales = "free") +
  theme_minimal() +
  theme(
    legend.position = "bottom",
    axis.text.x = element_blank(), # designs are identified by the legend
    text = element_text(size = 24)
  ) +
  labs(
    title = "Model Accuracy",
    # The facets cover every metric in `rates.df`, not only F1, so the axis
    # label is generic.
    y = "Metric Value (higher is more accurate)",
    x = "Model Type"
  )